from abc import ABCMeta, abstractmethod
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed
import torch.nn as nn

from diffusers import DiffusionPipeline
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from distvae.modules.adapters.vae.decoder_adapters import DecoderAdapter
from xfuser.config.config import (
    EngineConfig,
    InputConfig,
)
from xfuser.core.distributed.parallel_state import get_tensor_model_parallel_world_size
from xfuser.logger import init_logger
from xfuser.core.distributed import (
    get_data_parallel_world_size,
    get_sequence_parallel_world_size,
    get_pipeline_parallel_world_size,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    is_pipeline_first_stage,
    is_pipeline_last_stage,
    get_pp_group,
    get_world_group,
    get_runtime_state,
    initialize_runtime_state,
    is_dp_last_group,
)
from xfuser.core.fast_attention import (
    get_fast_attn_enable,
    initialize_fast_attn_state,
    fast_attention_compression,
)
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper

from xfuser.envs import PACKAGES_CHECKER

PACKAGES_CHECKER.check_diffusers_version()

from xfuser.model_executor.schedulers import *
from xfuser.model_executor.models.transformers import *
from xfuser.model_executor.layers.attention_processor import *

try:
    import os
    from onediff.infer_compiler import compile as od_compile

    HAS_OF = True
    os.environ["NEXFORT_FUSE_TIMESTEP_EMBEDDING"] = "0"
    os.environ["NEXFORT_FX_FORCE_TRITON_SDPA"] = "1"
except:
    HAS_OF = False

logger = init_logger(__name__)


class xFuserPipelineBaseWrapper(xFuserBaseWrapper, metaclass=ABCMeta):

    def __init__(
        self,
        pipeline: DiffusionPipeline,
        engine_config: EngineConfig,
    ):
        self.module: DiffusionPipeline
        self._init_runtime_state(pipeline=pipeline, engine_config=engine_config)
        self._init_fast_attn_state(pipeline=pipeline, engine_config=engine_config)

        # backbone
        transformer = getattr(pipeline, "transformer", None)
        unet = getattr(pipeline, "unet", None)
        # vae
        vae = getattr(pipeline, "vae", None)
        # scheduler
        scheduler = getattr(pipeline, "scheduler", None)

        if transformer is not None:
            pipeline.transformer = self._convert_transformer_backbone(
                transformer,
                enable_torch_compile=engine_config.runtime_config.use_torch_compile,
                enable_onediff=engine_config.runtime_config.use_onediff,
            )
        elif unet is not None:
            pipeline.unet = self._convert_unet_backbone(unet)

        if scheduler is not None:
            pipeline.scheduler = self._convert_scheduler(scheduler)

        if vae is not None and engine_config.runtime_config.use_parallel_vae and not self.use_naive_forward():
            pipeline.vae = self._convert_vae(vae)

        super().__init__(module=pipeline)

    def reset_activation_cache(self):
        if hasattr(self.module, "transformer") and hasattr(
            self.module.transformer, "reset_activation_cache"
        ):
            self.module.transformer.reset_activation_cache()
        if hasattr(self.module, "unet") and hasattr(
            self.module.unet, "reset_activation_cache"
        ):
            self.module.unet.reset_activation_cache()
        if hasattr(self.module, "vae") and hasattr(
            self.module.vae, "reset_activation_cache"
        ):
            self.module.vae.reset_activation_cache()
        if hasattr(self.module, "scheduler") and hasattr(
            self.module.scheduler, "reset_activation_cache"
        ):
            self.module.scheduler.reset_activation_cache()

    def to(self, *args, **kwargs):
        self.module = self.module.to(*args, **kwargs)
        return self

    @staticmethod
    def enable_fast_attn(func):
        @wraps(func)
        def fast_attn_fn(self, *args, **kwargs):
            if get_fast_attn_enable():
                for block in self.module.transformer.transformer_blocks:
                    for layer in block.children():
                        if isinstance(layer, xFuserAttentionBaseWrapper):
                            layer.stepi = 0
                            layer.cached_residual = None
                            layer.cached_output = None
                out = func(self, *args, **kwargs)
                for block in self.module.transformer.transformer_blocks:
                    for layer in block.children():
                        if isinstance(layer, xFuserAttentionBaseWrapper):
                            layer.stepi = 0
                            layer.cached_residual = None
                            layer.cached_output = None
                return out
            else:
                return func(self, *args, **kwargs)

        return fast_attn_fn

    @staticmethod
    def enable_data_parallel(func):
        @wraps(func)
        def data_parallel_fn(self, *args, **kwargs):
            prompt = kwargs.get("prompt", None)
            negative_prompt = kwargs.get("negative_prompt", "")
            # dp_degree <= batch_size
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            if batch_size > 1:
                dp_degree = get_runtime_state().parallel_config.dp_degree
                dp_group_rank = get_world_group().rank // (
                    get_world_group().world_size // get_data_parallel_world_size()
                )
                dp_group_batch_size = (batch_size + dp_degree - 1) // dp_degree
                start_batch_idx = dp_group_rank * dp_group_batch_size
                end_batch_idx = min(
                    (dp_group_rank + 1) * dp_group_batch_size, batch_size
                )
                prompt = prompt[start_batch_idx:end_batch_idx]
                if isinstance(negative_prompt, List):
                    negative_prompt = negative_prompt[start_batch_idx:end_batch_idx]
                kwargs["prompt"] = prompt
                if "negative_prompt" in kwargs:
                    kwargs["negative_prompt"] = negative_prompt
            return func(self, *args, **kwargs)

        return data_parallel_fn

    def use_naive_forward(self):
        return (
                get_pipeline_parallel_world_size() == 1
                and get_classifier_free_guidance_world_size() == 1
                and get_sequence_parallel_world_size() == 1
                and get_tensor_model_parallel_world_size() == 1
                and get_fast_attn_enable() == False
            )
        
    @staticmethod
    def check_to_use_naive_forward(func):
        @wraps(func)
        def check_naive_forward_fn(self, *args, **kwargs):
            if self.use_naive_forward():
                return self.module(*args, **kwargs)
            else:
                return func(self, *args, **kwargs)

        return check_naive_forward_fn

    @staticmethod
    def check_model_parallel_state(
        cfg_parallel_available: bool = True,
        sequence_parallel_available: bool = True,
        pipefusion_parallel_available: bool = True,
    ):
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                if (
                    not cfg_parallel_available
                    and get_runtime_state().parallel_config.cfg_degree > 1
                ):
                    raise RuntimeError("CFG parallelism is not supported by the model")
                if (
                    not sequence_parallel_available
                    and get_runtime_state().parallel_config.sp_degree > 1
                ):
                    raise RuntimeError(
                        "Sequence parallelism is not supported by the model"
                    )
                if (
                    not pipefusion_parallel_available
                    and get_runtime_state().parallel_config.pp_degree > 1
                ):
                    raise RuntimeError(
                        "Pipefusion parallelism is not supported by the model"
                    )
                return func(*args, **kwargs)

            return wrapper

        return decorator

    def forward(self):
        pass

    def prepare_run(
        self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1
    ):
        if get_fast_attn_enable():
            # set compression methods for DiTFastAttn
            fast_attention_compression(self)

        prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
        warmup_steps = get_runtime_state().runtime_config.warmup_steps
        get_runtime_state().runtime_config.warmup_steps = sync_steps
        self.__call__(
            height=input_config.height,
            width=input_config.width,
            prompt=prompt,
            use_resolution_binning=input_config.use_resolution_binning,
            num_inference_steps=steps,
            generator=torch.Generator(device="cuda").manual_seed(42),
            output_type=input_config.output_type,
        )
        get_runtime_state().runtime_config.warmup_steps = warmup_steps

    def latte_prepare_run(
        self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1
    ):
        prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
        warmup_steps = get_runtime_state().runtime_config.warmup_steps
        get_runtime_state().runtime_config.warmup_steps = sync_steps
        self.__call__(
            height=input_config.height,
            width=input_config.width,
            prompt=prompt,
            # use_resolution_binning=input_config.use_resolution_binning,
            num_inference_steps=steps,
            output_type="latent",
            generator=torch.Generator(device="cuda").manual_seed(42),
        )
        get_runtime_state().runtime_config.warmup_steps = warmup_steps

    def _init_runtime_state(
        self, pipeline: DiffusionPipeline, engine_config: EngineConfig
    ):
        initialize_runtime_state(pipeline=pipeline, engine_config=engine_config)

    def _init_fast_attn_state(
        self, pipeline: DiffusionPipeline, engine_config: EngineConfig
    ):
        initialize_fast_attn_state(pipeline=pipeline, single_config=engine_config.fast_attn_config)

    def _convert_transformer_backbone(
        self, transformer: nn.Module, enable_torch_compile: bool, enable_onediff: bool
    ):
        if (
            get_pipeline_parallel_world_size() == 1
            and get_sequence_parallel_world_size() == 1
            and get_classifier_free_guidance_world_size() == 1
            and get_tensor_model_parallel_world_size() == 1
            and get_fast_attn_enable() == False
        ):
            logger.info(
                "Transformer backbone found, but model parallelism is not enabled, "
                "use naive model"
            )
        else:
            logger.info("Transformer backbone found, paralleling transformer...")
            wrapper = xFuserTransformerWrappersRegister.get_wrapper(transformer)
            transformer = wrapper(transformer)

        if enable_torch_compile and enable_onediff:
            logger.warning(
                f"apply --use_torch_compile and --use_onediff togather. we use torch compile only"
            )

        if enable_torch_compile or enable_onediff:
            if getattr(transformer, "forward") is not None:
                if enable_torch_compile:
                    optimized_transformer_forward = torch.compile(
                        getattr(transformer, "forward")
                    )
                elif enable_onediff:
                    # O3: +fp16 reduction
                    if not HAS_OF:
                        raise RuntimeError(
                            "install onediff and nexfort to --use_onediff"
                        )
                    options = {"mode": "O3"}  # mode can be O2 or O3
                    optimized_transformer_forward = od_compile(
                        getattr(transformer, "forward"),
                        backend="nexfort",
                        options=options,
                    )
                setattr(transformer, "forward", optimized_transformer_forward)
            else:
                raise AttributeError(
                    f"Transformer backbone type: {transformer.__class__.__name__} has no attribute 'forward'"
                )
        return transformer

    def _convert_unet_backbone(
        self,
        unet: nn.Module,
    ):
        logger.info("UNet Backbone found")
        raise NotImplementedError("UNet parallelisation is not supported yet")

    def _convert_scheduler(
        self,
        scheduler: nn.Module,
    ):
        logger.info("Scheduler found, paralleling scheduler...")
        wrapper = xFuserSchedulerWrappersRegister.get_wrapper(scheduler)
        scheduler = wrapper(scheduler)
        return scheduler

    def _convert_vae(
        self,
        vae: AutoencoderKL,
    ):
        logger.info("VAE found, paralleling vae...")
        vae.decoder = DecoderAdapter(vae.decoder)
        return vae

    @abstractmethod
    def __call__(self):
        pass

    def _init_sync_pipeline(self, latents: torch.Tensor):
        get_runtime_state().set_patched_mode(patch_mode=False)

        latents_list = [
            latents[:, :, start_idx:end_idx, :]
            for start_idx, end_idx in get_runtime_state().pp_patches_start_end_idx_global
        ]
        latents = torch.cat(latents_list, dim=-2)
        return latents

    def _init_video_sync_pipeline(self, latents: torch.Tensor):
        get_runtime_state().set_patched_mode(patch_mode=False)
        latents_list = [
            latents[:, :, :, start_idx:end_idx, :]
            for start_idx, end_idx in get_runtime_state().pp_patches_start_end_idx_global
        ]
        latents = torch.cat(latents_list, dim=-2)
        return latents

    def _init_async_pipeline(
        self,
        num_timesteps: int,
        latents: torch.Tensor,
        num_pipeline_warmup_steps: int,
    ):
        get_runtime_state().set_patched_mode(patch_mode=True)

        if is_pipeline_first_stage():
            # get latents computed in warmup stage
            # ignore latents after the last timestep
            latents = (
                get_pp_group().pipeline_recv()
                if num_pipeline_warmup_steps > 0
                else latents
            )
            patch_latents = list(
                latents.split(get_runtime_state().pp_patches_height, dim=2)
            )
        elif is_pipeline_last_stage():
            patch_latents = list(
                latents.split(get_runtime_state().pp_patches_height, dim=2)
            )
        else:
            patch_latents = [
                None for _ in range(get_runtime_state().num_pipeline_patch)
            ]

        recv_timesteps = (
            num_timesteps - 1 if is_pipeline_first_stage() else num_timesteps
        )
        for _ in range(recv_timesteps):
            for patch_idx in range(get_runtime_state().num_pipeline_patch):
                get_pp_group().add_pipeline_recv_task(patch_idx)

        return patch_latents

    def _process_cfg_split_batch(
        self,
        negative_embeds: torch.Tensor,
        embeds: torch.Tensor,
        negative_embdes_mask: torch.Tensor = None,
        embeds_mask: torch.Tensor = None,
    ):
        if get_classifier_free_guidance_world_size() == 1:
            embeds = torch.cat([negative_embeds, embeds], dim=0)
        elif get_classifier_free_guidance_rank() == 0:
            embeds = negative_embeds
        elif get_classifier_free_guidance_rank() == 1:
            embeds = embeds
        else:
            raise ValueError("Invalid classifier free guidance rank")

        if negative_embdes_mask is None:
            return embeds

        if get_classifier_free_guidance_world_size() == 1:
            embeds_mask = torch.cat([negative_embdes_mask, embeds_mask], dim=0)
        elif get_classifier_free_guidance_rank() == 0:
            embeds_mask = negative_embdes_mask
        elif get_classifier_free_guidance_rank() == 1:
            embeds_mask = embeds_mask
        else:
            raise ValueError("Invalid classifier free guidance rank")
        return embeds, embeds_mask

    def is_dp_last_group(self):
        """Return True if in the last data parallel group, False otherwise.
        Also include parallel vae situation.
        """
        if get_runtime_state().runtime_config.use_parallel_vae and not self.use_naive_forward():
            return get_world_group().rank == 0
        else:
            return is_dp_last_group()
    
    def gather_broadcast_latents(self, latents:torch.Tensor):
        """gather latents from dp last group and broacast final latents
        """
        
        # ---------gather latents from dp last group-----------
        rank = get_world_group().rank
        device = f"cuda:{rank}"

        # all gather dp last group rank list
        dp_rank_list = [torch.zeros(1, dtype=int, device=device) for _ in range(get_world_group().world_size)]
        if is_dp_last_group():
            gather_rank = int(rank)
        else:
            gather_rank = -1
        torch.distributed.all_gather(dp_rank_list, torch.tensor([gather_rank],dtype=int,device=device))
        
        dp_rank_list = [int(dp_rank[0]) for dp_rank in dp_rank_list if int(dp_rank[0])!=-1]
        dp_last_group = torch.distributed.new_group(dp_rank_list)

        # gather latents from dp last group
        if rank == dp_rank_list[-1]:
            latents_list = [torch.zeros_like(latents) for _ in dp_rank_list]
        else:
            latents_list = None
        if rank in dp_rank_list:
            torch.distributed.gather(latents, latents_list, dst=dp_rank_list[-1], group=dp_last_group)

        if rank == dp_rank_list[-1]:
            latents = torch.cat(latents_list,dim=0)
        
        # ------broadcast latents to all nodes---------
        src = dp_rank_list[-1]
        latents_shape_len = torch.zeros(1,dtype=torch.int,device=device)
        
        # broadcast latents shape len
        if rank == src:
            latents_shape_len[0] = len(latents.shape)
        get_world_group().broadcast(latents_shape_len,src=src)
        
        # broadcast latents shape
        if rank == src:
            input_shape = torch.tensor(latents.shape,dtype=torch.int,device=device)
        else:
            input_shape = torch.zeros(latents_shape_len[0],dtype=torch.int,device=device)
        get_world_group().broadcast(input_shape,src=src)
        
        # broadcast latents
        if rank != src:
            dtype = get_runtime_state().runtime_config.dtype
            latents = torch.zeros(torch.Size(input_shape),dtype=dtype,device=device)
        get_world_group().broadcast(latents,src=src)

        return latents
